#include "streaming_asr_on_device.h"
// #include <android/log.h>
#include "dictionary.h"
#include "iniparser.h"
#include "rapidjson/document.h"
#include "rapidjson/rapidjson.h"
#include "rapidjson/stringbuffer.h"
#include "rapidjson/writer.h"
#include <algorithm>
#include <cmath>
#include <exception>
#include <fstream>
#include <iostream>
#include <regex>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// #define LOGD(...)                                                              \
//   __android_log_print(ANDROID_LOG_DEBUG, "LOG_TAG", __VA_ARGS__);

int ReadJsonMvn(
    const char *filename,
    std::unordered_map<std::string, std::vector<float>> *mvn_stats) {
  using namespace rapidjson;
  if (!filename)
    return -1;
  mvn_stats->clear();
  std::string json_content = ReadFile(filename);
  if (json_content.empty()) {
    return -1;
  }

  rapidjson::Document document;
  document.Parse(json_content.c_str());

  for (rapidjson::Value::ConstMemberIterator iter = document.MemberBegin();
       iter != document.MemberEnd(); ++iter) {
    std::vector<float> ans;
    if (iter->value.IsArray()) {
      if (iter->name.GetString() == "mean_stat") {
        const rapidjson::Value &temp = document[iter->name.GetString()];
        std::vector<float> ans;
        for (SizeType i = 0; i < temp.Size(); ++i) {
          ans.push_back(temp[i].GetFloat());
        }
        (*mvn_stats)[iter->name.GetString()] = ans;
        // mvn_stats->insert(iter->name.GetString(), ans);
      }
      if (iter->name.GetString() == "var_stat") {
        const rapidjson::Value &temp = document[iter->name.GetString()];
        std::vector<float> ans;
        for (SizeType i = 0; i < temp.Size(); ++i) {
          ans.push_back(temp[i].GetFloat());
        }
        // *mvn_stats[iter->name.GetString()] = ans;
        (*mvn_stats)[iter->name.GetString()] = ans;
      }
      if (iter->name.GetString() == "frame_num") {
        const rapidjson::Value &temp = document[iter->name.GetString()];
        int frame_num = iter->value.GetInt();
        std::vector<float> ans;
        ans.push_back(frame_num);
        (*mvn_stats)[iter->name.GetString()] = ans;
      }
    }
  }
  return 0;
}

// Reads an entire text file into a string.
// Returns "" when filename is null or the file cannot be opened (an error
// line is printed to stdout in the latter case, matching the original
// behavior callers may rely on for diagnostics).
std::string ReadFile(const char *filename) {
  if (!filename) {
    return "";
  }

  std::ifstream file(filename, std::ios_base::in);
  if (!file.is_open()) {
    std::cout << "!!!!!!! file " << filename << " is not open" << std::endl;
    return "";
  }

  // Slurp the whole stream through the underlying buffer.
  std::stringstream ss;
  ss << file.rdbuf();
  // NOTE(review): the original also echoed the full file contents to stdout;
  // that debug dump has been removed.
  return ss.str();
}

// Loads decoder parameters and model paths from an INI file.
// On any missing integer key, iniparser returns the supplied default (0).
void StreamingAsrConfig::Read(const char *config_filename) {
  dictionary *init = iniparser_load(config_filename);
  // iniparser_load returns NULL when the file cannot be parsed; the original
  // passed that NULL straight into the getters and iniparser_dump.
  if (init == nullptr) {
    std::cerr << "cannot load config file: "
              << (config_filename ? config_filename : "(null)") << std::endl;
    return;
  }
  const int not_found = 0;
  N_c = iniparser_getint(init, "params:right_context", not_found);
  N_l = iniparser_getint(init, "params:left_context", not_found);
  conv_context = iniparser_getint(init, "params:conv_context", not_found);
  max_audio_second =
      iniparser_getint(init, "params:max_audio_second", not_found);
  sample_rate = iniparser_getint(init, "params:sample_rate", not_found);
  output_dim = iniparser_getint(init, "params:output_dim", not_found);
  input_dim = iniparser_getint(init, "params:input_dim", not_found);
  model_path =
      iniparser_getstring(init, "model:model_path", "quantized torch model");
  vocab_path =
      iniparser_getstring(init, "model:vocab_path", "training token.txt model");
  // Default must be "" (not nullptr): assigning a NULL const char* returned
  // by iniparser_getstring to a std::string is undefined behavior.
  mvn_path = iniparser_getstring(init, "model:mvn_path", "");
  iniparser_dump(init, stdout);
  iniparser_freedict(init);
}

// Constructs the streaming recognizer: copies decoding parameters from the
// config, creates the feature pipeline, and (optionally) loads CMVN
// statistics to precompute per-dimension offset/scale into norm_.
// On any CMVN problem the constructor returns early with norm_ left empty,
// which makes applyCmvn() a no-op (same failure mode as the original).
StreamingAsrOnDevice::StreamingAsrOnDevice(const StreamingAsrConfig &asr_config)
    : asr_config_(asr_config), N_c_(asr_config.N_c), N_l_(asr_config.N_l),
      conv_context_(asr_config.conv_context),
      sample_rate_(asr_config.sample_rate),
      max_audio_second_(asr_config.max_audio_second),
      output_num_(asr_config.output_dim), input_dim_(asr_config.input_dim),
      inited_(false), acc_feats_num_(0), sample_num_(0), feats_offset_(0) {
  hpy_ = "";
  streaming_features_ = std::make_shared<OnlineStreaming>(asr_config);

  wenet::FeaturePipelineConfig feature_config;
  feature_pipeline_ = std::make_shared<wenet::FeaturePipeline>(feature_config);

  // CMVN is optional: only attempted when a path is configured.
  std::unordered_map<std::string, std::vector<float>> mvn_stats;
  if (asr_config.mvn_path != "" &&
      ReadJsonMvn(asr_config.mvn_path.c_str(), &mvn_stats) != 0) {
    std::cerr << "init cmvn stats failed" << std::endl;
    return;
  }

  std::vector<float> &stats_0 = mvn_stats["mean_stat"];
  std::vector<float> &stats_1 = mvn_stats["var_stat"];
  std::cout << "mean_stat size: " << stats_0.size() << std::endl;
  std::cout << "var_stat size: " << stats_1.size() << std::endl;
  // mean_stat must hold input_dim_ per-dim sums PLUS the frame count in its
  // last slot (index input_dim_); var_stat must hold input_dim_ squared sums.
  // The original only checked for non-emptiness and could read out of bounds.
  if (stats_0.size() < static_cast<size_t>(input_dim_) + 1 ||
      stats_1.size() < static_cast<size_t>(input_dim_)) {
    return;
  }

  // Frame count is loop-invariant; hoist it and reject a non-positive count
  // to avoid dividing by zero below.
  const float count = stats_0[input_dim_];
  if (count <= 0) {
    return;
  }

  // norm[0] = per-dim offset, norm[1] = per-dim scale, so that
  // normalized = feat * scale + offset  (see applyCmvn).
  std::vector<std::vector<float>> norm(2, std::vector<float>(input_dim_, 0));
  for (int d = 0; d < input_dim_; d++) {
    const float mean = stats_0[d] / count;
    float var = (stats_1[d] / count) - mean * mean;
    const float floor = 1.0e-20;
    if (var < floor) {
      // Floor tiny/negative variances to keep the scale finite.
      var = floor;
    }
    const float scale = 1.0 / sqrt(var);
    if (scale != scale || 1 / scale == 0.0) {
      // NaN or infinity in the cepstral mean/variance computation:
      // abandon CMVN entirely (norm_ stays empty).
      return;
    }
    norm[0][d] = -(mean * scale);
    norm[1][d] = scale;
  }
  norm_.push_back(norm[0]);
  norm_.push_back(norm[1]);
}

// Loads the TorchScript model and the vocabulary file.
// Returns 0 on success, 1 if the model cannot be loaded, 2 if the vocab
// file cannot be opened, 3 if the vocab file is malformed/truncated.
int StreamingAsrOnDevice::init() {
  try {
    std::cout << "loading...... " << asr_config_.model_path << std::endl;
    asr_model_ = torch::jit::load(asr_config_.model_path);
    asr_model_.run_method("reset");
  } catch (std::exception &e) {
    LOGD("ERROR: Cannot open model file:%s\n", asr_config_.model_path.c_str());
    return 1;
  }
  std::cout << "loading....... " << asr_config_.vocab_path << std::endl;
  std::ifstream f_vocab(asr_config_.vocab_path, std::ios::in);
  if (!f_vocab.is_open()) {
    LOGD("ERROR: Cannot open vocab file:%s", asr_config_.vocab_path.c_str());
    return 2;
  }
  // Vocab format: one "<token> <id>" pair per entry, output_num_ entries.
  // Stream extraction does not throw by default, so the original try/catch
  // could never fire and a short/garbled file silently produced garbage;
  // check the stream state after each read instead.
  for (int i = 0; i < output_num_; i++) {
    int idx;
    std::string word;
    if (!(f_vocab >> word >> idx)) {
      LOGD("ERROR: Read dict failed:%s\n", asr_config_.vocab_path.c_str());
      return 3;
    }
    vocab_[idx] = word;
  }
  f_vocab.close();

  inited_ = true;
  std::cout << "StreamingAsrOnDevice initialized." << std::endl;
  return 0;
}

// Returns the recognizer to its initial state so a new utterance can be
// decoded: clears the hypothesis, accumulated features and counters, and
// resets the feature pipelines and (if loaded) the model's internal state.
void StreamingAsrOnDevice::reset() {
  hpy_.clear();
  feats_.clear();
  acc_feats_num_ = 0;
  feats_offset_ = 0;
  sample_num_ = 0;

  feature_pipeline_->Reset();
  streaming_features_->Reset();

  // Only touch the TorchScript model once init() has loaded it.
  if (inited_) {
    asr_model_.run_method("reset");
  }
}

// Converts a raw 16-bit little-endian PCM buffer (len bytes) into floats
// scaled to [-1, 1); result is cleared first. An odd trailing byte is
// ignored.
void StreamingAsrOnDevice::rawPcm2Float(const char *buf, size_t len,
                                        std::vector<float> &result) {
  result.clear();

  // len is size_t, so the original `len < 0` check was always false;
  // reject null input and buffers shorter than one 16-bit sample instead.
  if (!buf || len < 2) {
    return;
  }
  const size_t sample_count = len / 2;
  result.resize(sample_count);
  const short *samples = reinterpret_cast<const short *>(buf);
  for (size_t i = 0; i < sample_count; i++) {
    result[i] = samples[i] / 32768.0;
  }
}

std::string StreamingAsrOnDevice::onRecieve(const float *wav, int len,
                                            bool is_last) {
  int status = 0;
  std::string message;
  std::string text;

  // feature extract
  std::vector<float> wav_vec(wav, wav + len);
  feature_pipeline_->AcceptWaveform(wav_vec);
  std::vector<float> feat;
  while (feature_pipeline_->ReadOne(&feat)) {
    // applyCmvn(feat);
    feats_.push_back(std::move(feat));
    acc_feats_num_++;
  }
  //   sample_num_ number accumulation
  sample_num_ += len;
  float audio_len = sample_num_ / 16000.0;
  if (audio_len > 15) {
    status = 1;
    message = "The audio length is too long (15s).";
    text = "";
    ResultJson res(status, message, text, sample_num_, audio_len);
    return res.GetJsonString();
  }

  // decode
  int start = 0, end = 0, input_len = 0, nc = 0;
  bool cnn_lookback = false, cnn_lookahead = false;
  if (is_last) {
    cnn_lookback = 0 < feats_offset_ - conv_context_;
    cnn_lookahead = false;
    start = std::max(0, feats_offset_ - conv_context_);
    end = acc_feats_num_;
    input_len = end - start;
    nc = input_len;
  } else if (acc_feats_num_ >= feats_offset_ + N_c_ + N_l_ + conv_context_) {
    // enough frames for lookback
    cnn_lookback = 0 < feats_offset_ - conv_context_;
    // lookahead always
    cnn_lookahead = true;
    start = std::max(0, feats_offset_ - conv_context_);
    end = feats_offset_ + N_c_ + N_l_ + conv_context_;
    input_len = end - start;
    nc = N_c_;
  } else {
    message = "Success";
    text = hpy_;
    ResultJson res(status, message, text, sample_num_, audio_len);
    return res.GetJsonString();
  }

  if (input_len == 0) {
    message = "Success";
    text = hpy_;
    ResultJson res(status, message, text, sample_num_, audio_len);
    return res.GetJsonString();
  }
  // do onrecieve
  torch::Tensor input_tensor = torch::ones({input_len, 80});
  for (int i = start; i < end; i++) {
    for (int d = 0; d < 80; d++) {
      input_tensor[i - start][d] = feats_[i][d];
    }
  }

  auto output = asr_model_
                    .run_method("on_recieve", input_tensor, nc, cnn_lookback,
                                cnn_lookahead)
                    .toTensor();
  for (int i = 0; i < output.size(0); i++) {
    int idx = output[i].item<int>();
    if (idx)
      hpy_ += vocab_[idx];
  }
  feats_offset_ += N_c_;

  if (is_last) {
    if (inited_)
      asr_model_.run_method("reset");
    feature_pipeline_->Reset();
  }

  message = "Success";
  text = hpy_;
  ResultJson res(status, message, text, sample_num_, audio_len);
  return res.GetJsonString();
}

std::string StreamingAsrOnDevice::ProcessBlock(const float *wav, int len,
                                               bool is_last) {
  int status = 0;
  std::string message;
  std::string text;

  // feature extract
  std::vector<float> wav_vec(wav, wav + len);
  feature_pipeline_->AcceptWaveform(wav_vec);
  std::vector<float> one_feat, temp_accum_feat;
  while (feature_pipeline_->ReadOne(&one_feat)) {
    // applyCmvn(feat);
    // feats_.push_back(std::move(one_feat));
    for (size_t i = 0; i < one_feat.size(); i++) {
      temp_accum_feat.push_back(one_feat[i]);
    }
    acc_feats_num_++;
  }
  std::cout << "finished computing features, size: " << temp_accum_feat.size()
            << std::endl;

  streaming_features_->ReceiveNewChunk(temp_accum_feat, is_last);
  //   sample_num_ number accumulation
  sample_num_ += len;
  float audio_len = sample_num_ / 16000.0;
  if (audio_len > 15) {
    status = 1;
    message = "The audio length is too long (15s).";
    text = "";
    ResultJson res(status, message, text, sample_num_, audio_len);
    return res.GetJsonString();
  }

  std::vector<float> block_features;
  streaming_features_->GetBlock(block_features);
  if (block_features.size() == 0) {
    message = "Success";
    text = hpy_;
    ResultJson res(status, message, text, sample_num_, audio_len);
    return res.GetJsonString();
  }

  torch::Tensor input_tensor =
      torch::ones({streaming_features_->CurrentBlockFrameSize(),
                   streaming_features_->GetDim()});
  for (int i = streaming_features_->GetTempStart();
       i < streaming_features_->GetTempEnd(); i++) {
    for (int d = 0; d < streaming_features_->GetDim(); d++) {
      input_tensor[i - streaming_features_->GetTempStart()][d] =
          streaming_features_->GetFeatureValue(i, d);
    }
  }

  auto output = asr_model_
                    .run_method("on_recieve", input_tensor,
                                streaming_features_->CurrentBlockSize(),
                                streaming_features_->LookBackward(),
                                streaming_features_->LookAhead())
                    .toTensor();
  for (int i = 0; i < output.size(0); i++) {
    int idx = output[i].item<int>();
    if (idx)
      hpy_ += " " + vocab_[idx];
  }
  feats_offset_ += N_c_;

  if (is_last) {
    if (inited_)
      asr_model_.run_method("reset");
    feature_pipeline_->Reset();
  }

  message = "Success";
  text = hpy_;
  ResultJson res(status, message, text, sample_num_, audio_len);
  return res.GetJsonString();
}

// In-place CMVN on one feature frame: feat[d] = feat[d] * scale[d] + offset[d].
// norm_ is either empty (CMVN disabled or stats failed to load -> no-op) or
// holds {offsets, scales} built by the constructor, each of length input_dim_.
void StreamingAsrOnDevice::applyCmvn(std::vector<float> &feat) {
  if (norm_.size() < 2) {
    return;
  }

  // Guard against frames shorter than input_dim_ instead of reading past
  // the end of feat (the original indexed feat[d] unconditionally).
  const int dim = std::min(static_cast<int>(feat.size()), input_dim_);
  for (int d = 0; d < dim; d++) {
    feat[d] = feat[d] * norm_[1][d] + norm_[0][d];
  }
}
